{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -Wno-x-partial #-}
-- | This module defines code for generating termination constraints.

module Language.Haskell.Liquid.Constraint.Termination (
  TCheck(..)
, mkTCheck
, doTermCheck
, makeTermEnvs
, makeDecrIndex
, checkIndex
, recType
, unOCons
, consCBSizedTys
, consCBWithExprs
) where

import           Data.Maybe ( fromJust, catMaybes, mapMaybe )
import qualified Data.List                          as L
import qualified Data.HashSet                       as S
import qualified Data.Traversable                   as T
import qualified Data.HashMap.Strict                as M
import           Control.Monad ( foldM, forM )
import           Control.Monad.State ( gets )
import           Text.PrettyPrint.HughesPJ ( (<+>), text )
import qualified Language.Haskell.Liquid.GHC.Misc                    as GM
import qualified Language.Fixpoint.Types            as F
import           Language.Haskell.Liquid.Constraint.Types (CG, CGInfo (..), CGEnv, makeRecInvariants)
import           Language.Haskell.Liquid.Constraint.Monad (addWarning)
import           Language.Haskell.Liquid.Constraint.Env (setTRec)
import           Language.Haskell.Liquid.Constraint.Template ( Template(..), unTemplate, varTemplate, safeFromAsserted, extender )
import           Language.Haskell.Liquid.Types.Errors
import           Language.Haskell.Liquid.Types.Fresh
import           Language.Haskell.Liquid.Types.RType
import           Language.Haskell.Liquid.Types.RTypeOp
import           Language.Haskell.Liquid.Types.Types
import           Language.Haskell.Liquid.Types.RefType (isDecreasing, makeDecrType, makeLexRefa, makeNumEnv)
import           Language.Haskell.Liquid.Misc (safeFromLeft, replaceN, (<->), zip4, safeFromJust, fst5)
import           Language.Haskell.Liquid.UX.Config
import qualified Liquid.GHC.API as GHC


-- | The kind of check selected for a binding: a termination check, a
--   strata check, or no check at all. 'mkTCheck' picks one from two flags.
data TCheck
  = TerminationCheck
  | StrataCheck
  | NoCheck
  deriving Show

-- | Select a 'TCheck' from two flags.
--
--   When the second flag is 'False' the result is always 'StrataCheck'.
--   Otherwise the first flag chooses between 'TerminationCheck' ('True')
--   and 'NoCheck' ('False').
mkTCheck :: Bool -> Bool -> TCheck
mkTCheck tc is =
  if is
    then (if tc then TerminationCheck else NoCheck)
    else StrataCheck

-- | Decide whether the binders introduced by a bind should be checked.
--
--   The answer is 'True' when both of the following hold:
--
--   * either @structuralTerm cfg@ is off, or at least one binder is a member
--     of @specTmVars@;
--   * no binder is a member of @specLazy@, and no binder satisfies the
--     \"no check\" predicate ('GM.isEmbeddedDictVar' when @typeclass cfg@ is
--     set, 'GM.isInternal' otherwise).
doTermCheck :: Config -> GHC.Bind GHC.Var -> CG Bool
doTermCheck cfg bind = do
  lazyVars <- gets specLazy
  termVars <- gets specTmVars
  let exempt v  = v `S.member` lazyVars || unchecked v
      requested = any (`S.member` termVars) binders
      wanted    = not (structuralTerm cfg) || requested
  pure (wanted && not (any exempt binders))
  where
    binders = GHC.bindersOf bind
    -- Which binders are never checked depends on the @typeclass@ option.
    unchecked
      | typeclass cfg = GM.isEmbeddedDictVar
      | otherwise     = GM.isInternal

-- | Build one environment per binder in @xes@, each carrying (via 'setTRec')
--   the types @ts'@ of all binders decorated with an 'OTerm' obligation.
--
--   For every binder the user-supplied expressions are looked up in @xtes@
--   and instantiated twice: once with the binder symbols of @ts@ replaced by
--   the lambda-bound variables collected from the binder's body, and once
--   with them replaced by the binder symbols of @ts'@. The obligation
--   attached to each type is the 'makeLexRefa' of one instantiation of the
--   second kind against one of the first kind.
--
--   Fails (through 'safeFromJust') when a binder has no entry in @xtes@.
makeTermEnvs :: CGEnv -> [(GHC.Var, [F.Located F.Expr])] -> [(GHC.Var, GHC.CoreExpr)]
             -> [SpecType] -> [SpecType]
             -> [CGEnv]
makeTermEnvs γ xtes xes ts ts' = [ setTRec γ (zip xs obligated) | obligated <- rts ]
  where
    (xs, ces)      = unzip xes

    -- Lambda-bound variables of each body, as many as the type has binders.
    argVars        = zipWith lambdaArgs ts ces
    lambdaArgs t   = collectArguments (length (ty_binds (toRTypeRep t)))

    -- Binder symbols of the given and of the primed types.
    declSyms       = map (fst5 . bkArrowDeep) ts
    primedSyms     = map (fst5 . bkArrowDeep) ts'

    -- Substitutions from the given binder symbols to the primed ones / to
    -- the actual lambda-bound variables.
    toPrimed       = zipWith mkSub declSyms primedSyms
    toArgs         = zipWith mkSub declSyms (map (map F.symbol) argVars)

    -- The expressions of every binder, under each of the two substitutions.
    exprss         = map exprsOf xs
    exprsOf x      = safeFromJust (err x) (L.lookup x xtes)
    viaArgs        = zipWith applySub toArgs   exprss
    viaPrimed      = zipWith applySub toPrimed exprss
    applySub su es = map (F.subst su) es

    -- One row of refinements per binder; each row has one entry per type.
    refass         = [ [ makeLexRefa primed args | primed <- viaPrimed ] | args <- viaArgs ]
    rts            = [ zipWith (addObligation OTerm) ts' refas | refas <- refass ]

    mkSub ys ys'   = F.mkSubst [(x, F.EVar y) | (x, y) <- zip ys ys']
    err x          = "Constant: makeTermEnvs: no terminating expression for " ++ GM.showPpr x

-- | Wrap the result of a (possibly quantified) function type in an 'RRTy'
--   carrying the given obligation kind and refinement. The type is taken
--   apart with 'bkUniv' and 'bkArrow' and reassembled with 'mkArrow', so its
--   quantifiers and argument binders are preserved.
addObligation :: Oblig -> SpecType -> RReft -> SpecType
addObligation o t r = mkArrow tyVars preds binders (RRTy [] r o result)
  where
    (tyVars, preds, body)                 = bkUniv t
    ((names, infos, args, refts), result) = bkArrow body
    binders                               = zip4 names infos args refts

--------------------------------------------------------------------------------
-- | TERMINATION TYPE ----------------------------------------------------------
--------------------------------------------------------------------------------

-- | Position of the decreasing argument of a binder, if any.
--
--   For 'Assumed' and 'Asserted' templates the index is computed by
--   'makeDecrIndexTy'; on failure the error is recorded as a warning and
--   'Nothing' is returned. Any other template yields 'Nothing' without a
--   warning.
makeDecrIndex :: (GHC.Var, Template SpecType) -> CG (Maybe Int)
makeDecrIndex (x, template) = case template of
  Assumed t  -> indexOrWarn t
  Asserted t -> indexOrWarn t
  _          -> return Nothing
  where
    indexOrWarn t =
      makeDecrIndexTy x t >>= either (\msg -> addWarning msg >> return Nothing) (return . Just)

-- | Index of the first argument type (after stripping 'OCons' wrappers) that
--   'isDecreasing' accepts, given the current @autoSize@ set and the numeric
--   environment built from all argument types. Returns a
--   \"No decreasing parameter\" termination error located at the binder when
--   there is none.
makeDecrIndexTy :: GHC.Var -> SpecType -> CG (Either (TError t) Int)
makeDecrIndexTy x st = do
  autosz <- gets autoSize
  let decreasing = isDecreasing autosz numEnv
  return $ maybe (Left noParam) Right (L.findIndex decreasing argTys)
  where
    noParam = ErrTermin (GHC.getSrcSpan x) [F.pprint x] (text "No decreasing parameter")
    argTys  = ty_args . toRTypeRep $ unOCons st
    numEnv  = makeNumEnv argTys

-- | Type to use for a recursive occurrence of a binder.
--
--   The first pair holds a list of variables and an optional index into it;
--   the triple holds an optional index into the binders of the type and the
--   type itself. When both indices are absent the type is returned as is;
--   otherwise the selected variable and binder are handed to 'makeRecType'.
--
--   Indexing uses '!!', so an out-of-range index is an error once the
--   corresponding value is demanded.
recType :: F.Symbolic a
        => S.HashSet GHC.TyCon
        -> (([a], Maybe Int), (t, Maybe Int, SpecType))
        -> SpecType
recType autoenv ((vs, indexc), (_, index, t)) =
  case (indexc, index) of
    (Nothing, Nothing) -> t
    _                  -> makeRecType autoenv t decrVar decrBind index
  where
    rep      = toRTypeRep (unOCons t)
    binders  = zip (ty_binds rep) (ty_args rep)
    decrVar  = fmap (vs !!) indexc
    decrBind = fmap (binders !!) index

-- | Check that an index is in range for both the collected variables and
--   the argument types of a binder's template.
--
--   With no index the result is 'Nothing'. Otherwise both lists are probed
--   with 'safeLogIndex' (each miss records its own warning) and the result
--   is the argument type at that index, if present.
checkIndex :: (GHC.NamedThing t, PPrint t, PPrint a)
           => (t, [a], Template (RType c tv r), Maybe Int)
           -> CG (Maybe (RType c tv r))
checkIndex (_, _, _, Nothing) = return Nothing
checkIndex (x, vs, t, Just index) = do
  -- The looked-up variable itself is not needed, only the warning on a miss.
  _ <- safeLogIndex noArg vs index
  safeLogIndex noParam argTys index
  where
    name    = F.pprint x
    loc     = GHC.getSrcSpan x
    argTys  = ty_args (toRTypeRep (unOCons (unTemplate t)))
    noArg   = ErrTermin loc [name] (text "No decreasing" <+> F.pprint index <-> text "-th argument on" <+> name <+> text "with" <+> F.pprint vs)
    noParam = ErrTermin loc [name] (text "No decreasing parameter")

-- | Replace the binder at position @is@ of a function type with the result
--   of 'makeDecrType' applied to the (variable, binder) pair, then restore
--   the 'OCons' wrappers of the original type with 'mergecondition'.
--
--   Partial: @is@ must be a 'Just' (it is unwrapped with 'fromJust'), and
--   'makeDecrType' must produce a 'Left' (unwrapped with 'safeFromLeft').
makeRecType :: (Enum a1, Eq a1, Num a1, F.Symbolic a)
            => S.HashSet GHC.TyCon
            -> SpecType
            -> Maybe a
            -> Maybe (F.Symbol, SpecType)
            -> Maybe a1
            -> SpecType
makeRecType autoenv t vs dxs is
  = mergecondition t (fromRTypeRep (rep { ty_binds = binds', ty_args = args' }))
  where
    rep             = toRTypeRep (unOCons t)
    binders         = zip (ty_binds rep) (ty_args rep)
    -- Present only when both the variable and the binder are present.
    decrPair        = (,) <$> vs <*> dxs
    decrBinder      = safeFromLeft "makeRecType" (makeDecrType autoenv decrPair)
    (binds', args') = unzip (replaceN (fromJust is) decrBinder binders)

-- | Remove 'RRTy' layers tagged 'OCons' from a type. The traversal descends
--   through 'RAllT', 'RAllP' and both sides of 'RFun'; any other constructor
--   is returned unchanged.
unOCons :: RType c tv r -> RType c tv r
unOCons ty = case ty of
  RAllT v t r      -> RAllT v (unOCons t) r
  RAllP p t        -> RAllP p (unOCons t)
  RFun x i tx t r  -> RFun x i (unOCons tx) (unOCons t) r
  RRTy _ _ OCons t -> unOCons t
  _                -> ty

-- | Walk two types in parallel, keeping the content of the second while
--   re-wrapping it in the 'OCons'-tagged 'RRTy' layers found in the first.
--   Where the two shapes stop agreeing, the second type is returned as is.
mergecondition :: RType c tv r -> RType c tv r -> RType c tv r
mergecondition old new = case (old, new) of
  (RAllT _ t1 _,        RAllT v t2 r2)         -> RAllT v (mergecondition t1 t2) r2
  (RAllP _ t1,          RAllP p t2)            -> RAllP p (mergecondition t1 t2)
  (RRTy xts r OCons t1, t2)                    -> RRTy xts r OCons (mergecondition t1 t2)
  (RFun _ _ t11 t12 _,  RFun x2 i t21 t22 r2)  -> RFun x2 i (mergecondition t11 t21) (mergecondition t12 t22) r2
  (_,                   t)                     -> t

-- | Total list lookup that reports a miss as a warning.
--
--   Returns @Just@ the element at position @n@ when it exists. For any
--   out-of-range index, including a negative one, the given error is
--   recorded with 'addWarning' and 'Nothing' is returned.
safeLogIndex :: Error -> [a] -> Int -> CG (Maybe a)
safeLogIndex err ls n
  -- 'drop' treats a negative count as zero, so rule negatives out first;
  -- '!!' would throw on them instead of warning.
  | n < 0     = missing
  | otherwise = case drop n ls of
      x : _ -> return (Just x)
      []    -> missing
  where
    missing = addWarning err >> return Nothing

-- RJ: AAAAAAARGHHH!!!!!! THIS CODE IS HORRIBLE!!!!!!!!!
-- | Generate constraints for a group of binders whose termination is
--   established through a decreasing argument.
--
--   The supplied @consBind@ callback is invoked (with its flag set to 'True')
--   once per binder, each time in an environment whose recursive-binder types
--   (installed with 'setTRec') were produced by 'recType' and which was then
--   passed through 'makeRecInvariants'. The returned environment is the input
--   environment extended (via 'extender') with the refreshed templates of all
--   binders.
--
--   NOTE(review): when 'checkSameLens' fails it returns @[]@, so @is@, @xeets@,
--   @rts@ and @γs@ are all empty and @consBind@ is not run for any binder;
--   only the warning is recorded. Confirm this is intended.
consCBSizedTys :: (Bool -> CGEnv -> (GHC.Var, GHC.CoreExpr, Template SpecType) -> CG (Template SpecType)) ->
                  CGEnv -> [(GHC.Var, GHC.CoreExpr)] -> CG CGEnv
consCBSizedTys consBind γ xes
       -- One template per binder, then the same templates with refreshed arguments.
  = do ts'      <- forM xes $ \(x, e) -> varTemplate γ (x, Just e)
       autoenv  <- gets autoSize
       ts       <- forM ts' $ T.mapM refreshArgs
       -- Lambda-bound variables of each body, as many as its type has binders.
       let vs    = zipWith collectArgs' ts es
       -- Decreasing index of each binder; all must agree on presence/absence.
       is       <- mapM makeDecrIndex (zip vars ts) >>= checkSameLens
       -- For each binder's (variables, index): pair it with the (var, index, type)
       -- triple of every binder in the group.
       let xeets = zipWith (\v i -> [((v,i), x) | x <- zip3 vars is $ map unTemplate ts]) vs is
       -- Run only for the warnings it may record; the result is discarded.
       _        <- mapM checkIndex (zip4 vars vs ts is) >>= checkEqTypes
       let rts   = (recType autoenv <$>) <$> xeets
       γ'       <- foldM extender γ (zip vars ts)
       -- One environment per binder; variables matching 'noMakeRec' are left
       -- out of the list handed to 'makeRecInvariants'.
       let γs    = zipWith makeRecInvariants [γ' `setTRec` zip vars rts' | rts' <- rts] (filter (not . noMakeRec) <$> vs)
       mapM_ (uncurry $ consBind True) (zip γs (zip3 vars es ts))
       return γ'
  where
       -- Predicate for variables excluded above; depends on the @typeclass@ option.
       noMakeRec      = if allowTC then GM.isEmbeddedDictVar else GM.isPredVar
       allowTC        = typeclass (getConfig γ)
       (vars, es)     = unzip xes
       dxs            = F.pprint <$> vars
       collectArgs'   = collectArguments . length . ty_binds . toRTypeRep . unOCons . unTemplate
       -- Warn unless all present types have the same 'toRSort' as the first.
       checkEqTypes :: [Maybe SpecType] -> CG [SpecType]
       checkEqTypes x = checkAllVsHead err1 toRSort (catMaybes x)
       err1           = ErrTermin loc dxs $ text "The decreasing parameters should be of same type"
       -- Warn unless 'length' (applied to each @Maybe Int@) matches that of the first.
       checkSameLens :: [Maybe Int] -> CG [Maybe Int]
       checkSameLens  = checkAllVsHead err2 length
       err2           = ErrTermin loc dxs $ text "All Recursive functions should have the same number of decreasing parameters"
       -- Partial on an empty group, but only demanded when an error is built.
       loc            = GHC.getSrcSpan (head vars)

       -- Return the list unchanged when every element maps (under @f@) to the
       -- same value as the head; otherwise record the warning and return @[]@.
       checkAllVsHead :: Eq b => Error -> (a -> b) -> [a] -> CG [a]
       checkAllVsHead _   _ []          = return []
       checkAllVsHead err f (x:xs)
         | all (== f x) (f <$> xs) = return (x:xs)
         | otherwise               = addWarning err >> return []

-- | Collect at most @n@ binders from the front of an expression: first the
--   value binders found by 'collectValBinders'' after the leading type
--   binders, then the binders that follow once any non-recursive @let@s at
--   that point have been skipped.
--
-- See Note [Shape of normalized terms] in
-- Language.Haskell.Liquid.Transforms.ANF
collectArguments :: Int -> GHC.CoreExpr -> [GHC.Var]
collectArguments n e = take n (outer ++ inner)
  where
    (_, afterTys) = GHC.collectTyBinders e
    (outer, rest) = collectValBinders' afterTys
    (inner, _)    = GHC.collectBinders (ignoreLetBinds rest)

-- | Peel leading lambdas off an expression, returning the 'GHC.isId' binders
--   in order together with the remaining expression. Type-variable binders
--   and ticks met along the way are skipped.
collectValBinders' :: GHC.CoreExpr -> ([GHC.Var], GHC.CoreExpr)
collectValBinders' expr = case expr of
  GHC.Lam b body
    | GHC.isTyVar b -> collectValBinders' body
    | GHC.isId    b -> let (bs, rest) = collectValBinders' body
                       in  (b : bs, rest)
  GHC.Tick _ body   -> collectValBinders' body
  _                 -> ([], expr)

-- | Skip over any leading non-recursive @let@ bindings and return the body
--   underneath. Every other expression is returned unchanged.
ignoreLetBinds :: GHC.Expr b -> GHC.Expr b
ignoreLetBinds expr = case expr of
  GHC.Let (GHC.NonRec _ _) body -> ignoreLetBinds body
  _                             -> expr

-- | Generate constraints for a group of binders whose termination is
--   established through user-supplied expressions (looked up in
--   @termExprs@).
--
--   Every binder must have an asserted template (enforced with
--   'safeFromAsserted'). The asserted types are refreshed and added to the
--   environment, 'makeTermEnvs' derives one environment per binder, and
--   @consBind@ is run (with its flag set to 'True') on each binder in its own
--   environment. The extended environment is returned.
consCBWithExprs :: (Bool -> CGEnv -> (GHC.Var, GHC.CoreExpr, Template SpecType) -> CG (Template SpecType)) ->
                   CGEnv -> [(GHC.Var, GHC.CoreExpr)] -> CG CGEnv
consCBWithExprs consBind γ xes = do
  templates <- mapM (\(x, e) -> varTemplate γ (x, Just e)) xes
  texprs    <- gets termExprs
  let declared = safeFromAsserted err <$> templates
  refreshed <- mapM refreshArgs declared
  let asserted = Asserted <$> refreshed
  γ'        <- foldM extender γ (zip xs asserted)
  -- Binders without an entry in @termExprs@ are dropped from this list.
  let xtes = mapMaybe (\x -> (x,) <$> M.lookup x texprs) xs
      envs = makeTermEnvs γ' xtes xes declared refreshed
  mapM_ (uncurry (consBind True)) (zip envs (zip3 xs es asserted))
  return γ'
  where
    (xs, es) = unzip xes
    err      = "Constant: consCBWithExprs"
